#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Dataset implementations. They are imported here so that their classes are
# available for construction from configs via utils.register below.
from .dataset_wrapper import MixDatasetWrapper
from .codesign import CoDesignDataset
from .confidence import ConfidenceDataset, BalancedConfidenceDataset
from .resample import ClusterResampler
from .mimicry import MimicryDataset
from .extend import ExtendDataset


import torch
from torch.utils.data import DataLoader

import utils.register as R
from utils.logger import print_log

def create_dataset(config: dict):
    """Instantiate the train/valid/test datasets from the config.

    Each split entry is either a single dataset config or a list of configs;
    a list is combined into one dataset with MixDatasetWrapper. Missing splits
    yield None.
    """
    splits = []
    for split_name in ['train', 'valid', 'test']:
        split_config = config.get(split_name, None)
        if split_config is None:
            splits.append(None)
            continue
        if isinstance(split_config, list):
            # multiple dataset configs: construct each one and mix them
            dataset = MixDatasetWrapper(
                *[R.construct(cfg) for cfg in split_config]
            )
        else:
            dataset = R.construct(split_config)
        splits.append(dataset)
    return splits  # [train, valid, test]
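
# Usage sketch (the config layout below is hypothetical; the exact keys that
# utils.register.construct expects depend on how each dataset class is
# registered):
#
#   config = {
#       'train': [cfg_codesign, cfg_mimicry],  # a list is mixed via MixDatasetWrapper
#       'valid': cfg_codesign_valid,           # a single config is constructed directly
#       # 'test' omitted -> the returned test split is None
#   }
#   train_set, valid_set, test_set = create_dataset(config)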


def create_dataloader(dataset, config: dict, n_gpu: int=1, validation: bool=False):
    """Wrap a dataset in a DataLoader according to the config.

    The configured batch size is the global batch size; with multiple GPUs it
    is divided by the number of GPUs and a DistributedSampler replaces the
    in-loader shuffling.
    """
    if 'wrapper' in config:
        # optionally wrap the dataset (e.g. with a resampling or mixing wrapper)
        dataset = R.construct(config['wrapper'], dataset=dataset)
    batch_size = config.get('batch_size', n_gpu)  # default: 1 sample per GPU
    if validation:
        batch_size = config.get('val_batch_size', batch_size)
    shuffle = config.get('shuffle', False)
    num_workers = config.get('num_workers', 4)
    # use the dataset's custom collate function if it provides one
    collate_fn = getattr(dataset, 'collate_fn', None)
    if n_gpu > 1:
        # distributed training: each process loads its own shard of the data
        sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
        batch_size = batch_size // n_gpu  # per-GPU batch size
        print_log(f'Batch size on a single GPU: {batch_size}')
    else:
        sampler = None
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=(shuffle and sampler is None),  # the sampler already handles shuffling
        collate_fn=collate_fn,
        sampler=sampler
    )
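

# Usage sketch (the 'dataset' / 'dataloader' section names are illustrative;
# only the keys read above, i.e. 'wrapper', 'batch_size', 'val_batch_size',
# 'shuffle' and 'num_workers', come from this module):
#
#   train_set, valid_set, _ = create_dataset(config['dataset'])
#   train_loader = create_dataloader(train_set, config['dataloader'], n_gpu=n_gpu)
#   valid_loader = create_dataloader(valid_set, config['dataloader'], n_gpu=n_gpu, validation=True)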